
FEA: add sparse tensor support for ngcf_conv and lightgcn_conv #75

Merged
merged 3 commits into RUCAIBox:main on Oct 23, 2023

Conversation

downeykking
Contributor

Added sparse tensor support to the ngcf conv and lightgcn conv layers. Metrics stay essentially unchanged, while in our tests GPU memory usage drops to about 1/5 and training runs about 5x faster.

All methods under general recommender have been adapted (wherever they use ngcf conv or lightgcn conv), so every model built on these backbones can benefit from the speedup. To stay compatible with the existing code logic, the acceleration is only turned on when a parameter is set manually:
assert config["enable_sparse"] in [True, False, None]

Usage:

Install torch_sparse.

Set the new parameter enable_sparse=True to turn on sparse tensor support.

If enable_sparse=True is set but torch_sparse is unavailable, the original dense edge_index is used as a fallback.
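The fallback behavior described above could be sketched like this (a minimal illustration only; `resolve_use_sparse` is a hypothetical helper name, not the actual function in this PR):

```python
# Hypothetical sketch of the enable_sparse fallback: sparse mode is
# used only when the user asked for it AND torch_sparse is importable;
# otherwise the original dense edge_index path is kept.
import importlib.util


def resolve_use_sparse(enable_sparse):
    """Return True only when sparse support was requested and
    torch_sparse is actually available in the environment."""
    assert enable_sparse in [True, False, None]
    if not enable_sparse:
        return False
    # Check availability without eagerly importing the package.
    return importlib.util.find_spec("torch_sparse") is not None
```

With this guard, a model can branch once at construction time and the rest of the forward pass stays unchanged whichever path is taken.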

Sequential recommenders also have convolution layers that could be accelerated the same way, but I have done less work on sequential recommendation; I will look into it when I have time and keep it as a TODO :)

Main test results:
For NGCF, 15 epochs on the ml-1m dataset:
Sparse: 0.18 GB GPU memory, 33.85s total training time
Origin: 1.01 GB GPU memory, 155.79s total training time

Sparse:
best valid : OrderedDict([('recall@10', 0.1096), ('recall@20', 0.1819), ('recall@50', 0.3257), ('mrr@10', 0.2796), ('mrr@20', 0.2888), ('mrr@50', 0.2933), ('ndcg@10', 0.1474), ('ndcg@20', 0.1628), ('ndcg@50', 0.2071), ('hit@10', 0.5884), ('hit@20', 0.7201), ('hit@50', 0.8541), ('precision@10', 0.1162), ('precision@20', 0.0988), ('precision@50', 0.0732)])
test result: OrderedDict([('recall@10', 0.124), ('recall@20', 0.1964), ('recall@50', 0.3377), ('mrr@10', 0.3341), ('mrr@20', 0.3422), ('mrr@50', 0.346), ('ndcg@10', 0.1817), ('ndcg@20', 0.1901), ('ndcg@50', 0.2296), ('hit@10', 0.6237), ('hit@20', 0.7398), ('hit@50', 0.8548), ('precision@10', 0.1429), ('precision@20', 0.1152), ('precision@50', 0.0815)])

Origin:
best valid : OrderedDict([('recall@10', 0.1085), ('recall@20', 0.1782), ('recall@50', 0.3184), ('mrr@10', 0.2797), ('mrr@20', 0.2882), ('mrr@50', 0.2927), ('ndcg@10', 0.1456), ('ndcg@20', 0.1603), ('ndcg@50', 0.2039), ('hit@10', 0.5871), ('hit@20', 0.7095), ('hit@50', 0.8456), ('precision@10', 0.1149), ('precision@20', 0.098), ('precision@50', 0.0731)])
test result: OrderedDict([('recall@10', 0.119), ('recall@20', 0.1906), ('recall@50', 0.3291), ('mrr@10', 0.3266), ('mrr@20', 0.3349), ('mrr@50', 0.3386), ('ndcg@10', 0.176), ('ndcg@20', 0.1848), ('ndcg@50', 0.2235), ('hit@10', 0.6159), ('hit@20', 0.7339), ('hit@50', 0.8466), ('precision@10', 0.1391), ('precision@20', 0.1136), ('precision@50', 0.0803)])

For LightGCN, 15 epochs on the ml-1m dataset:
Sparse: 0.17 GB GPU memory, 23.99s total training time
Origin: 1.02 GB GPU memory, 102.95s total training time

Sparse:
best valid : OrderedDict([('recall@10', 0.0891), ('recall@20', 0.1434), ('recall@50', 0.25), ('mrr@10', 0.25), ('mrr@20', 0.2587), ('mrr@50', 0.2632), ('ndcg@10', 0.1242), ('ndcg@20', 0.1326), ('ndcg@50', 0.1641), ('hit@10', 0.5214), ('hit@20', 0.6464), ('hit@50', 0.7824), ('precision@10', 0.0973), ('precision@20', 0.08), ('precision@50', 0.0583)])
test result: OrderedDict([('recall@10', 0.0956), ('recall@20', 0.1526), ('recall@50', 0.2579), ('mrr@10', 0.2876), ('mrr@20', 0.296), ('mrr@50', 0.3002), ('ndcg@10', 0.1452), ('ndcg@20', 0.1502), ('ndcg@50', 0.1783), ('hit@10', 0.5424), ('hit@20', 0.663), ('hit@50', 0.7912), ('precision@10', 0.1124), ('precision@20', 0.0905), ('precision@50', 0.0633)])

Origin:
best valid : OrderedDict([('recall@10', 0.0883), ('recall@20', 0.1433), ('recall@50', 0.2494), ('mrr@10', 0.2507), ('mrr@20', 0.2595), ('mrr@50', 0.2639), ('ndcg@10', 0.1242), ('ndcg@20', 0.1325), ('ndcg@50', 0.164), ('hit@10', 0.519), ('hit@20', 0.6454), ('hit@50', 0.7814), ('precision@10', 0.0974), ('precision@20', 0.08), ('precision@50', 0.0585)])
test result: OrderedDict([('recall@10', 0.0955), ('recall@20', 0.1524), ('recall@50', 0.258), ('mrr@10', 0.2864), ('mrr@20', 0.2949), ('mrr@50', 0.2992), ('ndcg@10', 0.1451), ('ndcg@20', 0.1499), ('ndcg@50', 0.1782), ('hit@10', 0.5406), ('hit@20', 0.6605), ('hit@50', 0.7908), ('precision@10', 0.1129), ('precision@20', 0.0905), ('precision@50', 0.0635)])

For SGL, 15 epochs on the ml-1m dataset:
Sparse: 0.60 GB GPU memory, 70.84s total training time
Origin: 1.09 GB GPU memory, 382.86s total training time

Sparse:
best valid : OrderedDict([('recall@10', 0.1299), ('recall@20', 0.1991), ('recall@50', 0.3302), ('mrr@10', 0.3238), ('mrr@20', 0.3315), ('mrr@50', 0.3348), ('ndcg@10', 0.1721), ('ndcg@20', 0.184), ('ndcg@50', 0.2245), ('hit@10', 0.638), ('hit@20', 0.7461), ('hit@50', 0.846), ('precision@10', 0.1311), ('precision@20', 0.1058), ('precision@50', 0.0753)])
test result: OrderedDict([('recall@10', 0.1414), ('recall@20', 0.215), ('recall@50', 0.3435), ('mrr@10', 0.3893), ('mrr@20', 0.3962), ('mrr@50', 0.3992), ('ndcg@10', 0.2085), ('ndcg@20', 0.2144), ('ndcg@50', 0.2491), ('hit@10', 0.6661), ('hit@20', 0.7645), ('hit@50', 0.8531), ('precision@10', 0.1559), ('precision@20', 0.1221), ('precision@50', 0.0824)])

Origin:
best valid : OrderedDict([('recall@10', 0.1263), ('recall@20', 0.1969), ('recall@50', 0.3253), ('mrr@10', 0.3225), ('mrr@20', 0.3304), ('mrr@50', 0.3337), ('ndcg@10', 0.1699), ('ndcg@20', 0.1824), ('ndcg@50', 0.2219), ('hit@10', 0.6308), ('hit@20', 0.7433), ('hit@50', 0.8432), ('precision@10', 0.1299), ('precision@20', 0.1054), ('precision@50', 0.0748)])
test result: OrderedDict([('recall@10', 0.1403), ('recall@20', 0.2102), ('recall@50', 0.3398), ('mrr@10', 0.3893), ('mrr@20', 0.3961), ('mrr@50', 0.3992), ('ndcg@10', 0.2076), ('ndcg@20', 0.2124), ('ndcg@50', 0.2473), ('hit@10', 0.6608), ('hit@20', 0.7567), ('hit@50', 0.8476), ('precision@10', 0.1554), ('precision@20', 0.1214), ('precision@50', 0.0823)])

@hyp1231
Member

hyp1231 commented Oct 23, 2023

orz, this is impressive. The new implementation is very elegant, and I learned a lot! May I ask whether the efficiency gain comes from torch_sparse wrapping more efficient sparse-matrix operators at a lower level? In theory the data layout seems similar to the original implementation, but perhaps the computation in the original was all naive Python operations?

@hyp1231 hyp1231 merged commit 979b219 into RUCAIBox:main Oct 23, 2023
1 check passed
@downeykking
Contributor Author

downeykking commented Oct 24, 2023

> orz, this is impressive. The new implementation is very elegant, and I learned a lot! May I ask whether the efficiency gain comes from torch_sparse wrapping more efficient sparse-matrix operators at a lower level? In theory the data layout seems similar to the original implementation, but perhaps the computation in the original was all naive Python operations?

In my understanding, the original implementation proceeds in two steps, message and aggregate, where aggregate is implemented via torch_scatter, gathering information from source nodes into target nodes, like this:

        from torch_scatter import scatter

        row, col = edge_index
        # x_j holds the aggregated features, one row per target node,
        # ordered 0 .. row.max().item() + 1
        x_j = scatter(x[col], row, dim=0, dim_size=x.size(0), reduce='mean')

So this path has to explicitly materialize the gathered tensor x[col] with edge_index.size(1) rows (usually a very large number; even though torch_scatter's author also optimized it with C++, memory usage and speed still fall short), and run the aggregation on top of it.
torch_sparse, by contrast, can fuse this into a single message_and_aggregate step, and the aggregation itself becomes a sparse-dense matrix multiplication:
x_j = matmul(adj_t, x, reduce=self.aggr)
The torch_sparse author rewrote many operators in C++, implementing fast forward and backward passes for GPU-based sparse matrix multiplication.

Reference: https://pytorch-geometric.readthedocs.io/en/latest/advanced/sparse_tensor.html
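To make the equivalence of the two paths concrete, here is a tiny numpy stand-in (illustrative only; the real code uses torch_scatter and torch_sparse on GPU tensors): the two-step gather + scatter-mean and the product of a row-normalized adjacency matrix with x compute the same neighborhood mean.

```python
import numpy as np

# A toy graph: row = target node, col = source node for each edge.
edge_index = np.array([[0, 0, 1, 2],   # targets
                       [1, 2, 0, 0]])  # sources
x = np.array([[1.0], [2.0], [4.0]])    # one feature per node

# Two-step path: gather x[col], then scatter-mean into target rows.
row, col = edge_index
agg = np.zeros_like(x)
cnt = np.zeros((x.shape[0], 1))
np.add.at(agg, row, x[col])   # sum neighbor features per target
np.add.at(cnt, row, 1.0)      # count neighbors per target
mean_two_step = agg / np.maximum(cnt, 1.0)

# One-step path: row-normalized adjacency times x (dense here purely
# for illustration; torch_sparse keeps A in a sparse layout).
A = np.zeros((3, 3))
A[row, col] = 1.0
A_norm = A / np.maximum(A.sum(1, keepdims=True), 1.0)
mean_one_step = A_norm @ x

assert np.allclose(mean_two_step, mean_one_step)
```

The memory difference comes from the intermediate: the two-step path materializes one row per edge (x[col]), while the fused sparse matmul never builds that edge-sized tensor.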

@hyp1231
Member

hyp1231 commented Oct 24, 2023

I see, that makes sense. Thanks for the explanation!!
